{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "80607b8d",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71e4438b",
   "metadata": {},
   "source": [
    "## Scene detection and fold split\n",
    "\n",
    "This tutorial shows how to detect different scenes and then split the fols to make sure each fold contain similar number of vides from similar number of scenes."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf31e5d",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e84de713",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import imagehash\" || pip -q install imagehash\n",
    "!python -c \"import iterstrat\" || pip install -q iterative-stratification\n",
    "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88476cba",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bfd4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from glob import glob\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
    "from PIL import Image\n",
    "import imagehash\n",
    "import multiprocessing\n",
    "from monai.config import print_config\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78880b1",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a34fdeba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24695, 16)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# replace the dir into your local dir\n",
    "\n",
    "df = pd.read_csv(\"/raid/surg/_release/training_data/labels.csv\")[[\"clip_name\", \"tools_present\"]]\n",
    "img_dir = \"/raid/surg/image640_blur/\"\n",
    "cpu_ct = 32\n",
    "\n",
    "\n",
    "def split_label(s):\n",
    "    return [x.strip(\" \") for x in s[1:-1].split(\",\")]\n",
    "\n",
    "\n",
    "label_lst = df.tools_present.apply(split_label).values.tolist()\n",
    "label_lst = [x for xs in label_lst for x in xs]\n",
    "\n",
    "labels = pd.Series(label_lst).value_counts().index.values[1:]\n",
    "for lb in labels:\n",
    "    df[lb] = df.tools_present.str.count(lb)\n",
    "\n",
    "labels = df.columns.values[2:]\n",
    "# uncomment the following line to produce the train.csv file\n",
    "# df.to_csv('../train.csv', index=False)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40043550",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(765000, 18)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_img = pd.DataFrame([os.path.basename(x) for x in sorted(glob(os.path.join(img_dir, \"*.jpg\")))], columns=[\"img_path\"])\n",
    "df_img[\"clip_name\"] = df_img.img_path.apply(lambda x: x[:11])\n",
    "\n",
    "df = df.merge(df_img, on=\"clip_name\", how=\"left\")\n",
    "df = df[pd.notna(df.img_path)]\n",
    "\n",
    "df[\"frame\"] = df.img_path.apply(lambda x: int(x[:-4].split(\"_\")[-1]))\n",
    "df = df.sort_values([\"clip_name\", \"frame\"]).reset_index(drop=True)\n",
    "\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "400c8df5",
   "metadata": {},
   "source": [
    "## Scene detection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4859b71",
   "metadata": {},
   "source": [
    "In this dataset, a number of (sometimes up dozens of) consecutive videos are from the same operation, or scene. Therefore, it is important to identify them and put videos from the same scene into the same fold when making fold splits, in order to prevent leakage in local validation.\n",
    "\n",
    "The way we detect scenes is to compare the image hashes of the last frame of a video against the first frame of the next video. If the similarity is above a threshold, they belong to the same scene. Otherwise, the next video is the start of the next scene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "898f678a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(49370, 19)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\"last\"] = df.frame.diff(-1)\n",
    "dfl = df[(df.frame == 0) | (df[\"last\"] > 0)].iloc[:-1]\n",
    "dfl = dfl.reset_index(drop=True)\n",
    "dfl.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d36d581a",
   "metadata": {},
   "outputs": [],
   "source": [
    "funcs = [\n",
    "    imagehash.average_hash,\n",
    "    imagehash.phash,\n",
    "    imagehash.dhash,\n",
    "    imagehash.whash,\n",
    "]\n",
    "\n",
    "\n",
    "def get_hash(img_path):\n",
    "    image = Image.open(f\"{img_dir}/{img_path}\")\n",
    "    return np.array([f(image).hash for f in funcs]).reshape(256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "48ac948c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 49370/49370 [01:28<00:00, 559.23it/s]\n"
     ]
    }
   ],
   "source": [
    "with multiprocessing.Pool(cpu_ct) as pool:\n",
    "    imap = pool.imap(get_hash, dfl.img_path.values)\n",
    "    hashes = list(tqdm(imap, total=len(dfl)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9e075dda",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(49369, 256)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hashes = np.stack(hashes)[:-1, :]\n",
    "hashes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3c262bbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24684,)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hash_diffs = (hashes[1::2, :] == hashes[2::2, :]).sum(1)\n",
    "hash_diffs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bc653e36",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfl[\"hash_sim\"] = 0\n",
    "dfl.loc[1 : len(dfl) - 2 : 2, \"hash_sim\"] = hash_diffs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b5a1d3a",
   "metadata": {},
   "source": [
    "Through visual inspections, **170** is a good cutoff hash similary difference.\n",
    "\n",
    "When the hash similarity is larger than or equal to 170, the two consecutive videos belong to the same scene.\n",
    "\n",
    "When the hash similarity is smaller than 170, the second video starts a new scene."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1a04b109",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1068, 4)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tmp = dfl[(dfl.frame != 0) & (dfl.hash_sim < 170)][[\"clip_name\", \"img_path\", \"hash_sim\"]].copy()\n",
    "tmp[\"EOS\"] = True\n",
    "tmp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "31c837e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24685, 19)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_scene = (\n",
    "    dfl[dfl.frame == 0]\n",
    "    .drop(columns=[\"frame\", \"last\", \"img_path\", \"hash_sim\"])\n",
    "    .merge(tmp[[\"clip_name\", \"EOS\"]], on=\"clip_name\", how=\"left\")\n",
    ")\n",
    "df_scene[\"EOS\"] = df_scene[\"EOS\"].fillna(0).astype(int)\n",
    "df_scene[\"SOS\"] = df_scene[\"EOS\"].shift(1).fillna(1).astype(int)\n",
    "df_scene[\"scene\"] = df_scene[\"SOS\"].cumsum() - 1\n",
    "df_scene.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fb46c1e",
   "metadata": {},
   "source": [
    "The number of scenes that each tool appears in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3dda6cba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "needle driver                   572\n",
      "cadiere forceps                 665\n",
      "bipolar forceps                 597\n",
      "monopolar curved scissors       560\n",
      "grasping retractor              223\n",
      "prograsp forceps                240\n",
      "force bipolar                   64\n",
      "vessel sealer                   93\n",
      "permanent cautery hook/spatula  56\n",
      "clip applier                    168\n",
      "tip-up fenestrated grasper      14\n",
      "stapler                         21\n",
      "bipolar dissector               1\n",
      "suction irrigator               4\n"
     ]
    }
   ],
   "source": [
    "for lb in labels:\n",
    "    print(f\"{lb:30}  {df_scene[df_scene[lb]>0].scene.nunique()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9be17bf",
   "metadata": {},
   "source": [
    "## Splitting folds based on scene number using iterative stratification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b0bf7bec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24695, 17)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"../train.csv\")\n",
    "labels = df.columns.values[2:]\n",
    "\n",
    "df = df.merge(df_scene[[\"clip_name\", \"scene\"]], how=\"left\", on=\"clip_name\")\n",
    "df[\"scene\"] = df[\"scene\"].fillna(df_scene.scene.nunique()).astype(int)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd171060",
   "metadata": {},
   "source": [
    "There are 1069 unique scenes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "dfcec05f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1069"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.scene.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a60cc805",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp = df.groupby(\"scene\")[labels].max().clip(0, 1).reset_index()\n",
    "\n",
    "X = tmp[labels].values\n",
    "y = tmp[labels].values\n",
    "tmp[\"fold\"] = -1\n",
    "\n",
    "mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n",
    "\n",
    "for i, (_, test_index) in enumerate(mskf.split(X, y)):\n",
    "    tmp.loc[test_index, \"fold\"] = i"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "678b6d59",
   "metadata": {},
   "source": [
    "Numbers of scenes in each fold are evenly distributed for all the tools except `bipolar dissector` (all the videos containing `bipolar dissector` are in the same scene):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee22438",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp.groupby(\"fold\")[labels].sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e70c1142",
   "metadata": {},
   "source": [
    "Numbers of videos in each fold are also nearly evenly distributed except for `bipolar dissector` (random_state can be adjusted in the `MultilabelStratifiedKFold` call above for a different distribution):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabf67a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.merge(tmp[[\"scene\", \"fold\"]], on=\"scene\", how=\"left\")\n",
    "df.groupby(\"fold\")[labels].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7c0fe3f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df.to_csv('../train_fold_balanced.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
